
Conversation

@ngxson (Collaborator) commented Oct 28, 2025

Supersedes #16822

Fix #13694 (hopefully this time for real)

The idea is to store the M-RoPE (x,y,t) positions inside KV cells. This will allow the causal mask to be constructed correctly based on (x,y,t) positions.

The benefit is that this introduces no breaking changes, compared to other proposals.
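
As a minimal sketch of the idea (assumed names, not the actual llama.cpp code; the comparison mirrors the is_gt() helper that appears later in this review):

    #include <cstdint>
    using llama_pos = int32_t; // matches the typedef in llama.h

    struct mrope_pos {
        llama_pos t; // temporal position (shared by all tokens of one image)
        llama_pos x; // column inside the image
        llama_pos y; // row inside the image
    };

    // a KV cell is masked out for the current token if the cell's position
    // comes strictly after the token's: later timestep first, then spatial
    // order breaks ties within the same image
    bool causal_masked(const mrope_pos & cell, const mrope_pos & tok) {
        if (cell.t != tok.t) {
            return cell.t > tok.t;
        }
        return (cell.y > tok.y) || (cell.y == tok.y && cell.x > tok.x);
    }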


This should now give the same output as #16822 across multiple values of -b

./build/bin/llama-mtmd-cli \
    -m "../models/Qwen2.5-VL-7B-Instruct-Q4_K_M.gguf" \
    --mmproj "../models/mmproj-Qwen2.5-VL-7B-Instruct-Q8_0.gguf" \
    --image "../models/0_bbox.png" \
    -p "Please first output bbox coordinates and colors of every rectangle in this image in JSON format, and then answer how many rectangles are there in the image." \
    --temp 0 -n 128
[
        {"bbox_2d": [168, 679, 462, 837], "color": "red"},
        {"bbox_2d": [312, 575, 480, 765], "color": "green"},
        {"bbox_2d": [601, 708, 672, 775], "color": "black"}
]

Image drawn using this script: https://gist.github.com/ngxson/039024fb2bdaf2e3c15db702f9fddaff


TODO:

  • fix KV save/load with mrope --> follow-up PR
  • also fix the mtmd_image_tokens_get_n_pos

@ngxson (Collaborator, Author) commented Oct 28, 2025

@ggerganov Could you have a quick look to see if this is indeed better than the other proposals? Thanks!

@ggerganov (Member)

Looking good on first look. Will take a detailed look tomorrow.

@FMayran commented Oct 28, 2025

I think @rujialiu and I thought of this possibility, and it really is the best solution, as it stores everything state-related in the KV cache, at the price of an increased KV cell size. Perhaps we only need to store the positional encoding of the last inserted token of a sequence as a property of the KV cache, not one position per cell?
Edit: unless we later want to remove part of the KV cache, in which case we need the per-cell positional encoding, of course.

@ngxson (Collaborator, Author) commented Oct 28, 2025

Perhaps we only need to store the positional encoding of the last inserted token of a sequence as a property of the KV cache, not one position per cell?

This could potentially work, but it would make the code more error-prone and more complicated to understand.

Storing the (x,y) position per cell is a more robust solution, with a small memory cost. Even with 128k tokens, this only uses an additional 2 (ints per cell) * 4 (bytes per int) * 128*1024 (cells) = 1048576 bytes = 1 MB.
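
A quick sanity check of that arithmetic (hypothetical two-field layout, matching the llama_kv_cell_ext struct proposed later in this review; llama_pos is a 32-bit integer):

    #include <cstdint>
    using llama_pos = int32_t; // matches the typedef in llama.h

    struct llama_kv_cell_ext {
        llama_pos x = 0;
        llama_pos y = 0;
    };
    static_assert(sizeof(llama_kv_cell_ext) == 8, "2 ints * 4 bytes per cell");
    // 8 bytes/cell * 128*1024 cells = 1048576 bytes = 1 MB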

@rujialiu
I think @rujialiu and I thought of this possibility, and it really is the best solution, as it stores everything state-related in the KV cache, at the price of an increased KV cell size. Perhaps we only need to store the positional encoding of the last inserted token of a sequence as a property of the KV cache, not one position per cell? Edit: unless we later want to remove part of the KV cache, in which case we need the per-cell positional encoding, of course.

Yes, I've thought of this as well, but I'm not confident enough to ensure nothing is broken 😄

Just one concern: I see that in llama-batch, has_mrope() only checks pos.size() vs token.size() (I know, there's nothing else we can do without some ugly tricks), so essentially we can only support "one kind of mrope". I don't know whether this will cause trouble later. That's why I didn't go further with this approach: I want to minimize the model-dependent multi-modal logic inside llama-batch.

But overall, this is the best solution so far.

@rujialiu
Storing the (x,y) position per cell is a more robust solution, with a small memory cost. Even with 128k tokens, this only uses an additional 2 (ints per cell) * 4 (bytes per int) * 128*1024 (cells) = 1048576 bytes = 1 MB.

I would definitely buy robustness and code simplicity with 1MB/128k tokens 😄

@rujialiu
Though it's not implemented yet (people have requested it in #16186), should we (conceptually) check that this approach can support Qwen3-Omni's M-RoPE (called TM-RoPE in its technical report) without much additional effort?

Comment on lines 20 to 23
bool has_mrope() const {
    return data->pos.size() == data->token.size()*4;
}

Member

I think we can make this multi-dimensional positional information more decoupled from the concept of RoPE:

diff --git a/src/llama-batch.h b/src/llama-batch.h
index 34f964ef0..8a6c6daff 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -17,8 +17,13 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
-    bool has_mrope() const {
-        return data->pos.size() == data->token.size()*4;
+    // typical for M-RoPE cases:
+    //   0 - sequential position of the tokens/embeddings in the sequence
+    //   1 - x position in the image
+    //   2 - y position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        return n_pos >= 3;
     }
 
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
@@ -29,6 +34,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -37,7 +43,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
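
Under this layout, a small accessor makes the flattened indexing explicit (a sketch assuming the [n_tokens*n_pos] layout and the llama_ubatch struct from the diff above; note that the discussion below debates whether dimension 1 holds x or y in the current Qwen ordering):

    // dimension d of token i lives at pos[i + d*n_tokens]
    static llama_pos ubatch_pos(const llama_ubatch & ub, uint32_t i, uint32_t d) {
        return ub.pos[i + d*ub.n_tokens];
    }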

Comment on lines 12 to 20
struct llama_kv_pos_mrope {
    llama_pos y = 0;
    llama_pos x = 0;
    // return true if this position is greater than the other position
    bool is_gt(const llama_kv_pos_mrope & other) const {
        return (y > other.y) || (y == other.y && x > other.x);
    }
};

Member

Similarly, I think we can decouple the concept of M-RoPE here by declaring this struct to be more generic:

struct llama_kv_cell_ext {
    // 2D spatial positions, typically used for M-RoPE
    llama_pos x = 0;
    llama_pos y = 0;

    // ... maybe more data in the future
};

Comment on lines 451 to 453
    // stores addition info for M-RoPE positions
    std::vector<llama_kv_pos_mrope> pos_mrope;

Member

Suggested change:
    - // stores addition info for M-RoPE positions
    - std::vector<llama_kv_pos_mrope> pos_mrope;
    + // stores extra optional cell info
    + std::vector<llama_kv_cell_ext> ext;

Comment on lines 1254 to 1259
llama_kv_pos_mrope p1_mrope;
if (ubatch->has_mrope()) {
    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
}

Member

I find it confusing to have the order of the positions as y, x. It's more canonical to have the dimensions ordered by increasing significance: x, y, z, .... This is also in line with the ggml convention for indexing.

I now notice that even the implementation of ggml_rope_multi uses this order. I would recommend updating this across the codebase for consistency. Even though it's a breaking change, it's better to do it now, before the mtmd stuff gets more adopted.

Collaborator Author

Yes, I agree that we should fix the ordering in ggml; I will make a PR for that.

@ngxson (Collaborator, Author) Oct 29, 2025

Hmm, on second thought, I think it cannot be ordered as x,y,z. This is because the full 4D position would be p,x,y,z, with p being the traditional LLM position.

Because Qwen doesn't use the last z dim, the ordering is currently p,y,x, which is decreasing in significance.

I think the better way is, as you suggested above, to decouple the logic into a 2D-specific check (is_pos_2d).

Member

Hm, not sure I follow. My point is that p,x,y,t is a more consistent order compared to the current p,y,x,t.

@broadbit-hu commented Oct 29, 2025

@ngxson Are we still sure we need this hparams.image_size = 1024 limitation? Scaling distorts the image, not just the relative positions.

                case PROJECTOR_TYPE_QWEN2VL:
                    {
                        // max image size = sqrt(max_pixels) = 3584
                        // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json
                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
                        hparams.image_size = 1024;
                        hparams.warmup_image_size = hparams.patch_size * 8;
                    } break;
                case PROJECTOR_TYPE_QWEN25VL:
                    {
                        // max image size = sqrt(max_pixels)
                        // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
                        hparams.image_size = 1024;
                        hparams.warmup_image_size = hparams.patch_size * 8;
                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                    } break;

See the differences between hparams.image_size (64x2x14=) 1792 (left) and 1024 (right):

(image: Handwritten-2025-10-29)

@ngxson (Collaborator, Author) commented Oct 29, 2025

@broadbit-hu The limit was added because some users reported out-of-memory issues. We will implement a custom image size limit in the near future, which should fix the issue.

@broadbit-hu commented Oct 29, 2025

@broadbit-hu The limit was added because some users reported out-of-memory issues. We will implement a custom image size limit in the near future, which should fix the issue.

@ngxson Yes, I also read the Hugging Face comment, so I ran tests with @FMayran's llama.cpp version using various image_size values. If we want to process an image of a letter-size (A4) page with small printed text, a 1024-pixel height results in very poor quality. Increasing 1024 to 1792 significantly improved the results, and I didn't encounter any memory issues. I'll also run tests with even larger values.

Another problem is that 1024 is not divisible by 28 (Qwen 2.5 VL: 28, Qwen 3 VL: 32), so if we really want a smaller value, 1008 or 1120 would be better.
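
For reference: 1008 = 36 × 28 and 1120 = 40 × 28, while 1024 = 36 × 28 + 16, so a 1024-pixel side cannot be tiled exactly by 28-pixel patches.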

[EDIT] I performed tests with 1792x1792-sized images too and successfully reproduced the mentioned memory issue on both AMD GPUs and CPUs. I reduced the image_size value to 1568, and now the tests are running.

@ngxson ngxson marked this pull request as ready for review October 29, 2025 10:25
@ngxson ngxson requested a review from ggerganov October 29, 2025 10:25
@ngxson (Collaborator, Author) commented Oct 29, 2025

Save/load seems to be tricky, as the apply_ubatch() call inside state_read_meta() only takes the 1-dim position input. I think it's safer to implement save/load of llama_kv_cell_ext in a follow-up PR.

@broadbit-hu commented Oct 29, 2025

@ngxson I also reviewed your rectangles test with my local build of FMayran's PR: compared to the original 1024x1024 resolution (right), manually scaling to 1008x1008 (left) gives significantly better results, just like 980x980 does. However, for this image, the 1120x1120 resolution does not improve anything, but rather degrades positional accuracy.

(image: Rectangles-test-20251029-1)

Prompt:

Please first output bbox coordinates and colors of every rectangle in this image in JSON bbox2d format.

@theo77186

The whole 1024px limitation will be moot as soon as flash attention is implemented for multimodal models (related: #16837); this will avoid excessive VRAM usage caused by large image sizes.

@ngxson (Collaborator, Author) commented Oct 29, 2025

For discussion regarding image sizes, please open a dedicated issue; it is not related to the current PR.

@ggerganov (Member) left a comment

Save/load seems to be tricky, as the apply_ubatch() call inside state_read_meta() only takes the 1-dim position input. I think it's safer to implement save/load of llama_kv_cell_ext in a follow-up PR.

Add a TODO with a reference to this PR to not forget about this.

Comment on lines +1731 to 1733
// TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
apply_ubatch(sinfo, ubatch);
Member

I don't understand this statement: apply_ubatch() does handle ext:

if (ubatch.is_pos_2d()) {
    llama_kv_cell_ext ext {
        /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
        /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
    };
    cells.ext_set(idx, std::move(ext));
}

@ngxson (Collaborator, Author) Oct 29, 2025

I mean that ext is constructed from pos, but ideally what I want is for apply_ubatch() to take the raw ext read from the save file.

The benefit is that when ext gains more info than just x,y, we won't need to update the save/load code again.

Another approach could be the other way around: on saving the state, we "serialize" ext back into a list of pos that can later be fed into the ubatch. But IMO this is a bit hacky.
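
For illustration, a rough sketch of that "serialize back into pos" alternative (hypothetical helper names; ext_get() and pos_out are assumptions, not the actual state-save code), mirroring the indexing used by apply_ubatch() above:

    // when writing the state, expand each cell's ext back into the flattened
    // multi-dim pos layout that apply_ubatch() already understands
    for (uint32_t i = 0; i < n_tokens; ++i) {
        pos_out[i]              = cells.pos_get(i);   // sequential position
        pos_out[i + n_tokens]   = cells.ext_get(i).y; // dim 1 (y in the current layout)
        pos_out[i + n_tokens*2] = cells.ext_get(i).x; // dim 2 (x in the current layout)
        pos_out[i + n_tokens*3] = 0;                  // unused 4th dim
    }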

@ggerganov (Member) left a comment

Just pushed a few more changes: bed0f57

Should be good to merge now.

@ngxson ngxson merged commit e3af556 into ggml-org:master Oct 29, 2025
62 of 63 checks passed
@easyfab commented Oct 30, 2025

Hi, since this commit, with LightOnOCR-1B-1025-Q8_0.gguf and a blank page I get an infinite stream of \n\n\n...
I saw it when doing OCR with a Python script on PDFs containing blank pages.
Command line: llama-server -m ../models/LightOnOCR-1B-1025-Q8_0.gguf --mmproj ../models/mmproj/mmproj-LightOnOCR-1B-1025-Q8_0.gguf -c 8192 --jinja

And confirmation with the webui (screenshot attached).



Successfully merging this pull request may close these issues.

Eval bug: Qwen2.5-VL-7B-Instruct returns extremely inaccurate bbox coordinates
